// 作者：https://github.com/8891689
// zuc_avx2.c 
#include "zuc_avx2.h"
#include <string.h>
#include <immintrin.h>

/* S-box and D constants */

/* 256-entry byte substitution table S0. Its entries are widened into
 * S0_32bit by init_sbox_data_avx2() so the AVX2 code can gather from it. */
static const uint8_t S0[256] =  {
0x3e,0x72,0x5b,0x47,0xca,0xe0,0x00,0x33,0x04,0xd1,0x54,0x98,0x09,0xb9,0x6d,0xcb,
0x7b,0x1b,0xf9,0x32,0xaf,0x9d,0x6a,0xa5,0xb8,0x2d,0xfc,0x1d,0x08,0x53,0x03,0x90,
0x4d,0x4e,0x84,0x99,0xe4,0xce,0xd9,0x91,0xdd,0xb6,0x85,0x48,0x8b,0x29,0x6e,0xac,
0xcd,0xc1,0xf8,0x1e,0x73,0x43,0x69,0xc6,0xb5,0xbd,0xfd,0x39,0x63,0x20,0xd4,0x38,
0x76,0x7d,0xb2,0xa7,0xcf,0xed,0x57,0xc5,0xf3,0x2c,0xbb,0x14,0x21,0x06,0x55,0x9b,
0xe3,0xef,0x5e,0x31,0x4f,0x7f,0x5a,0xa4,0x0d,0x82,0x51,0x49,0x5f,0xba,0x58,0x1c,
0x4a,0x16,0xd5,0x17,0xa8,0x92,0x24,0x1f,0x8c,0xff,0xd8,0xae,0x2e,0x01,0xd3,0xad,
0x3b,0x4b,0xda,0x46,0xeb,0xc9,0xde,0x9a,0x8f,0x87,0xd7,0x3a,0x80,0x6f,0x2f,0xc8,
0xb1,0xb4,0x37,0xf7,0x0a,0x22,0x13,0x28,0x7c,0xcc,0x3c,0x89,0xc7,0xc3,0x96,0x56,
0x07,0xbf,0x7e,0xf0,0x0b,0x2b,0x97,0x52,0x35,0x41,0x79,0x61,0xa6,0x4c,0x10,0xfe,
0xbc,0x26,0x95,0x88,0x8a,0xb0,0xa3,0xfb,0xc0,0x18,0x94,0xf2,0xe1,0xe5,0xe9,0x5d,
0xd0,0xdc,0x11,0x66,0x64,0x5c,0xec,0x59,0x42,0x75,0x12,0xf5,0x74,0x9c,0xaa,0x23,
0x0e,0x86,0xab,0xbe,0x2a,0x02,0xe7,0x67,0xe6,0x44,0xa2,0x6c,0xc2,0x93,0x9f,0xf1,
0xf6,0xfa,0x36,0xd2,0x50,0x68,0x9e,0x62,0x71,0x15,0x3d,0xd6,0x40,0xc4,0xe2,0x0f,
0x8e,0x83,0x77,0x6b,0x25,0x05,0x3f,0x0c,0x30,0xea,0x70,0xb7,0xa1,0xe8,0xa9,0x65,
0x8d,0x27,0x1a,0xdb,0x81,0xb3,0xa0,0xf4,0x45,0x7a,0x19,0xdf,0xee,0x78,0x34,0x60
};

/* 256-entry byte substitution table S1. Its entries are widened into
 * S1_32bit by init_sbox_data_avx2() so the AVX2 code can gather from it. */
static const uint8_t S1[256] = {
0x55,0xc2,0x63,0x71,0x3b,0xc8,0x47,0x86,0x9f,0x3c,0xda,0x5b,0x29,0xaa,0xfd,0x77,
0x8c,0xc5,0x94,0x0c,0xa6,0x1a,0x13,0x00,0xe3,0xa8,0x16,0x72,0x40,0xf9,0xf8,0x42,
0x44,0x26,0x68,0x96,0x81,0xd9,0x45,0x3e,0x10,0x76,0xc6,0xa7,0x8b,0x39,0x43,0xe1,
0x3a,0xb5,0x56,0x2a,0xc0,0x6d,0xb3,0x05,0x22,0x66,0xbf,0xdc,0x0b,0xfa,0x62,0x48,
0xdd,0x20,0x11,0x06,0x36,0xc9,0xc1,0xcf,0xf6,0x27,0x52,0xbb,0x69,0xf5,0xd4,0x87,
0x7f,0x84,0x4c,0xd2,0x9c,0x57,0xa4,0xbc,0x4f,0x9a,0xdf,0xfe,0xd6,0x8d,0x7a,0xeb,
0x2b,0x53,0xd8,0x5c,0xa1,0x14,0x17,0xfb,0x23,0xd5,0x7d,0x30,0x67,0x73,0x08,0x09,
0xee,0xb7,0x70,0x3f,0x61,0xb2,0x19,0x8e,0x4e,0xe5,0x4b,0x93,0x8f,0x5d,0xdb,0xa9,
0xad,0xf1,0xae,0x2e,0xcb,0x0d,0xfc,0xf4,0x2d,0x46,0x6e,0x1d,0x97,0xe8,0xd1,0xe9,
0x4d,0x37,0xa5,0x75,0x5e,0x83,0x9e,0xab,0x82,0x9d,0xb9,0x1c,0xe0,0xcd,0x49,0x89,
0x01,0xb6,0xbd,0x58,0x24,0xa2,0x5f,0x38,0x78,0x99,0x15,0x90,0x50,0xb8,0x95,0xe4,
0xd0,0x91,0xc7,0xce,0xed,0x0f,0xb4,0x6f,0xa0,0xcc,0xf0,0x02,0x4a,0x79,0xc3,0xde,
0xa3,0xef,0xea,0x51,0xe6,0x6b,0x18,0xec,0x1b,0x2c,0x80,0xf7,0x74,0xe7,0xff,0x21,
0x5a,0x6a,0x54,0x1e,0x41,0x31,0x92,0x35,0xc4,0x33,0x07,0x0a,0xba,0x7e,0x0e,0x34,
0x88,0xb1,0x98,0x7c,0xf3,0x3d,0x60,0x6c,0x7b,0xca,0xd3,0x1f,0x32,0x65,0x04,0x28,
0x64,0xbe,0x85,0x9b,0x2f,0x59,0x8a,0xd7,0xb0,0x25,0xac,0xaf,0x12,0x03,0xe2,0xf2
};

/* D constants (15-bit values). zuc_init_8ch() shifts D[i] left by 8 and
 * ORs it between the key byte (bits 23..30) and the IV byte (bits 0..7)
 * when building initial LFSR cell i. */
static const uint16_t D[16] = {
    0x44D7, 0x26BC, 0x626B, 0x135E, 0x5789, 0x35E2, 0x7135, 0x09AF,
    0x4D78, 0x2F13, 0x6BC4, 0x1AF1, 0x5E26, 0x3C4D, 0x789A, 0x47AC
};

// ===================== High-performance AVX2 S-box core =====================
// 32-bit-per-entry copies of S0 and S1, filled at run time by
// init_sbox_data_avx2(). process_sbox_avx2() gathers from them with
// _mm256_i32gather_epi32 using a 4-byte scale, which is why each byte entry
// is stored in its own 32-bit slot.
static uint32_t S0_32bit[256] __attribute__((aligned(32)));
static uint32_t S1_32bit[256] __attribute__((aligned(32)));

// Populate the 32-bit gather tables S0_32bit / S1_32bit from the byte tables
// S0 / S1. Must run before the first S-box lookup; zuc_init_8ch() calls it
// once on its first invocation.
// Declared with a (void) parameter list so the definition is a real prototype
// and a call that passes arguments is rejected by the compiler.
static void init_sbox_data_avx2(void) {
    for (int i = 0; i < 256; i++) {
        // Zero-extend each byte entry into its own 32-bit slot.
        S0_32bit[i] = S0[i];
        S1_32bit[i] = S1[i];
    }
}

// Substitute every byte of each 32-bit lane of `word` through the S-boxes.
// From the most significant byte down, the tables used are S0, S1, S0, S1;
// the substituted bytes are reassembled in their original positions.
static inline __m256i sbox_word_lookup_avx2(__m256i word) {
    const __m256i byte_mask = _mm256_set1_epi32(0x000000FF);

    // Split each lane into four byte indices (the top byte needs no mask
    // because the logical right shift already clears the upper bits).
    const __m256i idx_top  = _mm256_srli_epi32(word, 24);
    const __m256i idx_hi   = _mm256_and_si256(_mm256_srli_epi32(word, 16), byte_mask);
    const __m256i idx_lo   = _mm256_and_si256(_mm256_srli_epi32(word, 8), byte_mask);
    const __m256i idx_bot  = _mm256_and_si256(word, byte_mask);

    // Table lookups: scale 4 because the tables hold one byte per 32-bit entry.
    const __m256i sub_top = _mm256_i32gather_epi32((const int*)S0_32bit, idx_top, 4);
    const __m256i sub_hi  = _mm256_i32gather_epi32((const int*)S1_32bit, idx_hi, 4);
    const __m256i sub_lo  = _mm256_i32gather_epi32((const int*)S0_32bit, idx_lo, 4);
    const __m256i sub_bot = _mm256_i32gather_epi32((const int*)S1_32bit, idx_bot, 4);

    // Put each substituted byte back at its original bit position.
    const __m256i lower_half = _mm256_or_si256(sub_bot, _mm256_slli_epi32(sub_lo, 8));
    const __m256i upper_half = _mm256_or_si256(_mm256_slli_epi32(sub_hi, 16),
                                               _mm256_slli_epi32(sub_top, 24));
    return _mm256_or_si256(lower_half, upper_half);
}

// Run the S-box layer on two 8-lane vectors.
//   u_in, v_in   - input words (one 32-bit word per lane).
//   sbox_u_out   - receives the substituted form of u_in.
//   sbox_v_out   - receives the substituted form of v_in.
// The U result is stored before the V result is computed.
static inline void process_sbox_avx2(__m256i u_in, __m256i v_in, __m256i* sbox_u_out, __m256i* sbox_v_out) {
    *sbox_u_out = sbox_word_lookup_avx2(u_in);
    *sbox_v_out = sbox_word_lookup_avx2(v_in);
}


// ===================== Helper functions =====================
// Rotate every 32-bit lane of x left by n bits.
static inline __m256i rotl32_avx2(__m256i x, int n) {
    const __m256i shifted_up = _mm256_slli_epi32(x, n);
    const __m256i wrapped_bits = _mm256_srli_epi32(x, 32 - n);
    return _mm256_or_si256(shifted_up, wrapped_bits);
}

// Rotate the low 31 bits of every lane of x left by (n mod 31) positions.
// Bit 31 of the input is ignored and bit 31 of the result is always zero.
static inline __m256i rotl31_avx2(__m256i x, uint32_t n) {
    const uint32_t shift = n % 31;
    const __m256i low31 = _mm256_set1_epi32(0x7FFFFFFF);
    const __m256i value = _mm256_and_si256(x, low31);
    const __m256i rotated = _mm256_or_si256(_mm256_slli_epi32(value, shift),
                                            _mm256_srli_epi32(value, 31 - shift));
    // The left shift can push bits into position 31; drop them.
    return _mm256_and_si256(rotated, low31);
}

// Lane-wise addition with end-around carry on 31 bits: the carry out of
// bit 30 of (a + b) is folded back into bit 0. For inputs that fit in 31 bits
// this is addition modulo 2^31 - 1, except that a result congruent to zero
// may be represented as 0x7FFFFFFF.
static inline __m256i mod_add31_avx2(__m256i a, __m256i b) {
    const __m256i low31 = _mm256_set1_epi32(0x7FFFFFFF);
    const __m256i raw = _mm256_add_epi32(a, b);
    const __m256i carry = _mm256_srli_epi32(raw, 31);
    return _mm256_add_epi32(_mm256_and_si256(raw, low31), carry);
}

static inline __m256i L1_avx2(__m256i x) {
    __m256i rot_x2 = rotl32_avx2(x, 2);
    __m256i rot_x10 = rotl32_avx2(x, 10);
    __m256i rot_x18 = rotl32_avx2(x, 18);
    __m256i rot_x24 = rotl32_avx2(x, 24);
    __m256i result = _mm256_xor_si256(x, rot_x2);
    result = _mm256_xor_si256(result, rot_x10);
    result = _mm256_xor_si256(result, rot_x18);
    result = _mm256_xor_si256(result, rot_x24);
    return result;
}

static inline __m256i L2_avx2(__m256i x) {
    __m256i rot_x8 = rotl32_avx2(x, 8);
    __m256i rot_x14 = rotl32_avx2(x, 14);
    __m256i rot_x22 = rotl32_avx2(x, 22);
    __m256i rot_x30 = rotl32_avx2(x, 30);
    __m256i result = _mm256_xor_si256(x, rot_x8);
    result = _mm256_xor_si256(result, rot_x14);
    result = _mm256_xor_si256(result, rot_x22);
    result = _mm256_xor_si256(result, rot_x30);
    return result;
}

// ===================== Core algorithm =====================
// Clock all eight lanes once: bit reorganization, the nonlinear function F
// (which replaces R1/R2), then one LFSR feedback-and-shift.
//   state  - eight-lane cipher state; lfsr[], R1 and R2 are updated in place.
//            When state->is_init_mode is non-zero, W >> 1 is also fed back
//            into the LFSR.
//   W_out  - if non-NULL, receives W (computed from the pre-update R1/R2).
//   X3_out - if non-NULL, receives the reorganized word X3.
static void zuc_step_8ch(zuc_state_8ch* state, __m256i* W_out, __m256i* X3_out) {
    const __m256i low16_mask = _mm256_set1_epi32(0xFFFF);
    const __m256i bits30_15_mask = _mm256_set1_epi32(0x7FFF8000);

    // --- Bit reorganization ---
    // X0 = bits 30..15 of lfsr[15] (moved up by one) || low 16 bits of lfsr[14].
    const __m256i x0_upper = _mm256_slli_epi32(_mm256_and_si256(state->lfsr[15], bits30_15_mask), 1);
    const __m256i x0_lower = _mm256_and_si256(state->lfsr[14], low16_mask);
    const __m256i X0 = _mm256_or_si256(x0_upper, x0_lower);

    // X1..X3 = low 16 bits of one cell || (another cell >> 15).
    const __m256i x1_upper = _mm256_slli_epi32(_mm256_and_si256(state->lfsr[11], low16_mask), 16);
    const __m256i X1 = _mm256_or_si256(x1_upper, _mm256_srli_epi32(state->lfsr[9], 15));

    const __m256i x2_upper = _mm256_slli_epi32(_mm256_and_si256(state->lfsr[7], low16_mask), 16);
    const __m256i X2 = _mm256_or_si256(x2_upper, _mm256_srli_epi32(state->lfsr[5], 15));

    const __m256i x3_upper = _mm256_slli_epi32(_mm256_and_si256(state->lfsr[2], low16_mask), 16);
    const __m256i X3_val = _mm256_or_si256(x3_upper, _mm256_srli_epi32(state->lfsr[0], 15));

    // --- Nonlinear function F ---
    // All three words below read R1/R2 before the S-box call overwrites them.
    const __m256i W_val = _mm256_add_epi32(_mm256_xor_si256(X0, state->R1), state->R2);
    const __m256i W1 = _mm256_add_epi32(state->R1, X1);
    const __m256i W2 = _mm256_xor_si256(state->R2, X2);

    // Swap halves between W1 and W2, then apply the linear transforms.
    const __m256i u = L1_avx2(_mm256_or_si256(_mm256_slli_epi32(W1, 16), _mm256_srli_epi32(W2, 16)));
    const __m256i v = L2_avx2(_mm256_or_si256(_mm256_slli_epi32(W2, 16), _mm256_srli_epi32(W1, 16)));

    // S-box layer writes the new R1 / R2 directly into the state.
    process_sbox_avx2(u, v, &state->R1, &state->R2);

    // --- LFSR feedback ---
    // Sum of rotated taps (rotations act on 31-bit values) plus lfsr[0] itself,
    // accumulated with 31-bit end-around-carry addition.
    __m256i feedback = _mm256_setzero_si256();
    feedback = mod_add31_avx2(feedback, rotl31_avx2(state->lfsr[15], 15));
    feedback = mod_add31_avx2(feedback, rotl31_avx2(state->lfsr[13], 17));
    feedback = mod_add31_avx2(feedback, rotl31_avx2(state->lfsr[10], 21));
    feedback = mod_add31_avx2(feedback, rotl31_avx2(state->lfsr[4], 20));
    feedback = mod_add31_avx2(feedback, rotl31_avx2(state->lfsr[0], 8));
    feedback = mod_add31_avx2(feedback, state->lfsr[0]);

    if (state->is_init_mode) {
        // Initialization mode additionally mixes in W >> 1.
        feedback = mod_add31_avx2(feedback, _mm256_srli_epi32(W_val, 1));
    }

    // Shift the register down by one cell and append the new value.
    memmove(state->lfsr, state->lfsr + 1, 15 * sizeof(__m256i));
    state->lfsr[15] = feedback;

    if (W_out) {
        *W_out = W_val;
    }
    if (X3_out) {
        *X3_out = X3_val;
    }
}

// Initialize eight ZUC instances in parallel.
//   state - state block to set up; any previous contents of the fields
//           written here are overwritten.
//   keys  - eight 16-byte keys, one per lane.
//   ivs   - eight 16-byte IVs, one per lane.
// On the very first call the 32-bit S-box gather tables are built. The
// function then loads the LFSR, clears R1/R2, runs 32 initialization-mode
// steps and leaves the state in working mode with the "discard first
// output" step still pending (see zuc_generate_8ch).
void zuc_init_8ch(zuc_state_8ch* state, const uint8_t keys[8][16], const uint8_t ivs[8][16]) {
    // One-time construction of the gather tables.
    static int tables_ready = 0;
    if (tables_ready == 0) {
        init_sbox_data_avx2();
        tables_ready = 1;
    }

    // Keep a copy of the key/IV material inside the state.
    memcpy(state->keys, keys, sizeof(state->keys));
    memcpy(state->ivs, ivs, sizeof(state->ivs));

    // Build LFSR cell i for every lane: key byte << 23 | D[i] << 8 | IV byte.
    for (int cell = 0; cell < 16; cell++) {
        uint32_t lanes[8] __attribute__((aligned(32)));
        const uint32_t d_bits = (uint32_t)D[cell] << 8;
        for (int lane = 0; lane < 8; lane++) {
            const uint32_t key_bits = (uint32_t)keys[lane][cell] << 23;
            lanes[lane] = key_bits | d_bits | ivs[lane][cell];
        }
        // Aligned load is valid because `lanes` is 32-byte aligned.
        state->lfsr[cell] = _mm256_load_si256((__m256i*)lanes);
    }

    state->R1 = _mm256_setzero_si256();
    state->R2 = _mm256_setzero_si256();

    // 32 rounds in initialization mode; outputs are not needed.
    state->is_init_mode = 1;
    int rounds_left = 32;
    while (rounds_left > 0) {
        zuc_step_8ch(state, NULL, NULL);
        rounds_left--;
    }
    state->is_init_mode = 0;

    // 0 means the one-time discarded working-mode step has not run yet.
    state->discard_initial_output = 0;
}

// Produce one 32-bit keystream word for each of the eight lanes.
//   state  - state previously set up by zuc_init_8ch; advanced in place.
//   output - receives eight words, lane 0 first (unaligned store, so no
//            alignment requirement on the buffer).
// On the first call after initialization one extra step is executed and its
// result thrown away before the real output step.
void zuc_generate_8ch(zuc_state_8ch* state, uint32_t output[8]) {
    if (!state->discard_initial_output) {
        zuc_step_8ch(state, NULL, NULL);
        state->discard_initial_output = 1;
    }

    __m256i f_word;
    __m256i x3_word;
    zuc_step_8ch(state, &f_word, &x3_word);

    // Keystream word = W xor X3.
    _mm256_storeu_si256((__m256i*)output, _mm256_xor_si256(f_word, x3_word));
}

// Wipe a state block, including the key/IV copies it holds.
//   state - state to clear; NULL is accepted and ignored.
// The whole structure is overwritten with zero bytes.
void zuc_clear_8ch(zuc_state_8ch* state) {
    if (state == NULL) {
        return;
    }

    memset(state, 0, sizeof(zuc_state_8ch));

    // Compiler barrier: tells the optimizer the zeroed memory may be read
    // here, so the memset above cannot be dropped as a dead store when this
    // function is inlined and the state is not used afterwards.
    __asm__ __volatile__("" : : "r"(state) : "memory");
}
